# -*- coding: utf-8 -*-
"""
Created on Fri May 12 14:34:52 2017

@author: Y.Gou

Example of running the validation:
validate_classes(out_fldr = 'c:\\validation\\',inRaster = 'c:\\scotland_classification_map.tif',shpdir = 'c:\\validation_shps\\scot\\')
"""

import csv
import errno
import os
import subprocess

import numpy as np
from osgeo import gdal


def getallinfo(g):
    """Returns georeferencing and size information about a GDAL dataset.

    Args:
    g (gdal dataset): the object returned by gdal.Open(filename), where filename
                      is a GDAL-supported raster

    Returns:
    tuple: (x_min, pixel_width, x_rotation, y_max, y_rotation, pixel_height,
            rows, cols, bands, x_max, y_min)

    The first six items are the dataset's geotransform in GDAL order. x_max and
    y_min are derived from the origin, the pixel size and the raster size only
    (the rotation terms are not used), so they describe the extent exactly only
    for rasters without rotation.
    """
    # Unpack both rotation terms separately; they are distinct geotransform entries.
    (x_min, pixel_width, x_rotation,
     y_max, y_rotation, pixel_height) = g.GetGeoTransform()
    rows = g.RasterYSize
    cols = g.RasterXSize
    bands = g.RasterCount
    x_max = (cols * pixel_width) + x_min
    # pixel_height is typically negative (north-up), so this steps down from y_max.
    y_min = y_max + (rows * pixel_height)
    return (x_min, pixel_width, x_rotation, y_max, y_rotation, pixel_height,
            rows, cols, bands, x_max, y_min)

def clipShp(input_raster,shpfile_path,output_fldr):
    '''
    Clip a shapefile to the bounding extent of a raster using ``ogr2ogr -clipdst``.

    Args:
    input_raster: an opened GDAL dataset (as returned by gdal.Open()); its extent is used as the clip window
    shpfile_path: path of the shapefile to be clipped
    output_fldr: folder in which the clipped shapefile 'outline_clip.shp' is written

    Returns:
    the path of the clipped shapefile (output_fldr/outline_clip.shp)
    '''
    (x_min, _, _, y_max, _, _, _, _, _, x_max, y_min) = getallinfo(input_raster)

    clipfile = os.path.join(output_fldr, 'outline_clip.shp')
    remove_exist(clipfile)  # clear a clip result left over from a previous call
    print('The clipped shapefile to the extent of the raster, resultant shp is saved in ' + clipfile)
    # Pass the arguments as a list (no shell) so paths containing spaces or
    # shell metacharacters reach ogr2ogr intact. ogr2ogr takes dst before src.
    cmd = ['ogr2ogr', '-clipdst',
           str(x_min), str(y_min), str(x_max), str(y_max),
           clipfile, shpfile_path]
    status = subprocess.call(cmd)
    if status != 0:
        print('Warning: ogr2ogr exited with status ' + str(status))
    return clipfile

def remove_exist(filename):
    '''
    Delete ``filename`` if it is present; a missing file is silently ignored.

    Used to clear temporary outputs left behind by an earlier run. Any other
    OSError (for example permission denied) is propagated to the caller.
    '''
    try:
        os.remove(filename)
    except OSError as err:
        if err.errno == errno.ENOENT:
            return  # file was not there: nothing to clean up
        raise

def cal_ratio(alldata, classnumber):
    '''
    Return [classnumber, count, ratio] for the entries of alldata matching classnumber.

    An entry matches when it equals classnumber or, for string/container
    entries, contains it (e.g. 'forest' matches 'forest_pts').

    Args:
    alldata: a list of class numbers or class names
    classnumber: the class number or class name to count

    Returns:
    [classnumber, number of matching entries, matching fraction rounded to
    6 decimals]; the fraction is 0.0 when alldata is empty.
    '''
    def _matches(item):
        if item == classnumber:
            return True
        try:
            return classnumber in item
        except TypeError:
            # scalar entries (e.g. ints) or int-in-str: containment does not apply
            return False

    total = len(alldata)
    count = sum(1 for item in alldata if _matches(item))
    # float() keeps the division true division under Python 2 as well
    ratio = round(float(count) / total, 6) if total else 0.0
    return [classnumber, count, ratio]

def class_type(inputname):
    '''
    Map a shapefile name or a stringified class value to a land-cover class label.

    The checks run in order and the first one that matches wins; most are
    substring tests (e.g. 'gra' or '2' anywhere in inputname -> 'Grassland').
    Wetland and Other are matched by name or by an exact numeric value of 5 / 6
    (so '5' and '5.0' both work).

    Args:
    inputname: a string such as a shapefile base name ('forest_pts.shp') or a
               class value converted with str() (e.g. '3.0')

    Returns:
    the class label, or inputname unchanged (with a message printed) when
    nothing matches
    '''
    def _is_code(code):
        # Exact numeric comparison; the original compared the string to an int,
        # which was always False and let '5.0'/'6.0' fall through to 'Nodata'.
        try:
            return float(inputname) == code
        except (TypeError, ValueError):
            return False

    if 'forest' in inputname or '1' in inputname:
        classname = 'Forest'
    elif 'gra' in inputname or '2' in inputname:
        classname = 'Grassland'
    elif 'ara' in inputname or '3' in inputname:
        classname = 'Cropland'
    elif 'settlement' in inputname or '4' in inputname:
        classname = 'Settlement'
    elif 'Wetland' in inputname or _is_code(5):
        classname = 'Wetland'
    elif 'Other' in inputname or _is_code(6):
        classname = 'Other'
    elif 'water' in inputname or '7' in inputname:
        classname = 'Water'
    elif 'nodata' in inputname or '0' in inputname:
        classname = 'Nodata'
    else:
        print('cannot find matching names')
        classname = inputname
    return classname


def _rasterize_points(clipfile, field_name, x_min, y_min, x_max, y_max,
                      pixel_width, pixel_height, out_raster):
    '''
    Burn field_name of the point shapefile clipfile into the Float32 raster
    out_raster on the classification map's grid (0 is written as nodata).
    '''
    layer_name = os.path.splitext(os.path.basename(clipfile))[0]
    cmd = ['gdal_rasterize', '-a_nodata', '0', '-a', field_name,
           '-ot', 'Float32', '-l', layer_name,
           '-te', str(x_min), str(y_min), str(x_max), str(y_max),
           '-tr', str(pixel_width), str(-pixel_height),
           clipfile, out_raster]
    status = subprocess.call(cmd)
    if status != 0:
        print('Warning: gdal_rasterize exited with status ' + str(status))


def validate_classes(inRaster, shpdir, field_name='GRID_CODE', out_fldr=None):
    '''
    Validate a classification map against per-class validation point shapefiles.

    Args:
    inRaster: the classification map to be validated (e.g. a .tif)
    shpdir  : folder (searched recursively) holding the validation points; each class
              is stored as a separate point shapefile. The label reported for each
              shapefile is derived from its name by the 'class_type' function.
    field_name: the attribute of the validation points holding their class value;
                it is burnt into a raster, and any value > 0 marks a validation pixel.
    out_fldr: folder the resulting .csv (and temporary clip/raster files) are written to.
              Required; it is a keyword with a default only because it follows
              field_name in the signature.

    Returns:
    nothing; writes <raster name>_validation.csv with the columns
    'validate class', 'predicted class', 'number' and 'ratio'.

    Example
    validate_classes(out_fldr = 'c:\\validation\\',inRaster = 'c:\\scotland_classification_map.tif',shpdir = 'c:\\validation_shps\\scot\\')
    '''
    if out_fldr is None:
        raise TypeError("validate_classes() missing required argument: 'out_fldr'")

    ###############################################################################
    # read in the validation shapefiles and the classification map to be validated
    ###############################################################################
    # Join with dirpath (not shpdir) so shapefiles in sub-folders get a valid path.
    shpfile_path_list = [os.path.join(dirpath, f)
                         for dirpath, _dirnames, files in os.walk(shpdir)
                         for f in files if f.endswith('.shp')]
    s0data = gdal.Open(inRaster)
    if s0data is None:
        raise IOError('could not open raster: ' + inRaster)

    # Raster geometry and class values do not change between shapefiles.
    (x_min, pixel_width, _, y_max, _, pixel_height,
     _, _, _, x_max, y_min) = getallinfo(s0data)
    bs_array = s0data.GetRasterBand(1).ReadAsArray()

    ###########################################
    # Run Validation
    ###########################################
    raster_stem = os.path.splitext(os.path.basename(inRaster))[0]
    csv_path = os.path.join(out_fldr, raster_stem + '_' + "validation.csv")
    clipd_shp_rst = os.path.join(out_fldr, 'groundata_raster.tif')

    # 'wb' is what the Python 2 csv module expects.
    with open(csv_path, 'wb') as fs:
        writer = csv.writer(fs)
        writer.writerow(['validate class', 'predicted class', 'number', 'ratio'])

        for shpfile_path in shpfile_path_list:
            # one shapefile of points per validation class
            validshpname = class_type(os.path.basename(shpfile_path)[:20])
            print(shpfile_path)
            print('~validating ... ' + validshpname)

            # clip the points to the extent of the raster
            clipfile = clipShp(s0data, shpfile_path, out_fldr)

            print('rasterise the shapefile')
            remove_exist(clipd_shp_rst)  # drop the raster from the previous shapefile
            _rasterize_points(clipfile, field_name, x_min, y_min, x_max, y_max,
                              pixel_width, pixel_height, clipd_shp_rst)

            r = gdal.Open(clipd_shp_rst)
            if r is None:
                raise IOError('could not open rasterised validation points: ' + clipd_shp_rst)
            shp_array1 = r.GetRasterBand(1).ReadAsArray()
            # Release the dataset so the file can be deleted on the next
            # iteration (an open handle blocks deletion on Windows).
            r = None

            shp_array1[shp_array1 > 0] = 1  # 1 where validation points are, 0 elsewhere
            new = shp_array1 * bs_array

            # calculate stats over the pixels that carry a non-zero prediction
            total_nonnum = np.count_nonzero(new)
            for i in np.unique(new):
                if i == 0:
                    continue
                num = int(np.count_nonzero(new == i))
                ratio = round(float(num) / float(total_nonnum), 6)
                stat = [validshpname, class_type(str(i)), num, ratio]
                writer.writerow(stat)
                print(stat)
            writer.writerow([])



